{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cbbc1fe3-bff1-4631-bf35-342e19c54cc0",
   "metadata": {},
   "source": [
    "<table style=\"width:100%\">\n",
    "<tr>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<font size=\"2\">\n",
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "</font>\n",
    "</td>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
    "</td>\n",
    "</tr>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b022374-e3f6-4437-b86f-e6f8f94cbebc",
   "metadata": {},
   "source": [
    "# Extending the Tiktoken BPE Tokenizer with New Tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bcd624b1-2060-49af-bbf6-40517a58c128",
   "metadata": {},
   "source": [
    "- This notebook explains how we can extend an existing BPE tokenizer; specifically, we will focus on how to do it for the popular [tiktoken](https://github.com/openai/tiktoken) implementation\n",
    "- For a general introduction to tokenization, please refer to [Chapter 2](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch02/01_main-chapter-code/ch02.ipynb) and the BPE from Scratch [link] tutorial\n",
    "- For example, suppose we have a GPT-2 tokenizer and want to encode the following text:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "798d4355-a146-48a8-a1a5-c5cec91edf2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[15496, 11, 2011, 3791, 30642, 62, 16, 318, 257, 649, 11241, 13, 220, 50256]\n"
     ]
    }
   ],
   "source": [
    "import tiktoken\n",
    "\n",
    "base_tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
    "sample_text = \"Hello, MyNewToken_1 is a new token. <|endoftext|>\"\n",
    "\n",
    "token_ids = base_tokenizer.encode(sample_text, allowed_special={\"<|endoftext|>\"})\n",
    "print(token_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b09b19b-772d-4449-971b-8ab052ee726d",
   "metadata": {},
   "source": [
    "- Iterating over each token ID can give us a better understanding of how the token IDs are decoded via the vocabulary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "21fd634b-bb4c-4ba3-8b69-9322b727bf58",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15496 -> Hello\n",
      "11 -> ,\n",
      "2011 ->  My\n",
      "3791 -> New\n",
      "30642 -> Token\n",
      "62 -> _\n",
      "16 -> 1\n",
      "318 ->  is\n",
      "257 ->  a\n",
      "649 ->  new\n",
      "11241 ->  token\n",
      "13 -> .\n",
      "220 ->  \n",
      "50256 -> <|endoftext|>\n"
     ]
    }
   ],
   "source": [
    "for token_id in token_ids:\n",
    "    print(f\"{token_id} -> {base_tokenizer.decode([token_id])}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd5b1b9b-b1a9-489e-9711-c15a8e081813",
   "metadata": {},
   "source": [
    "- As we can see above, the string `\"MyNewToken_1\"` is split into 5 individual subword tokens; this is normal BPE behavior for words that are not in the vocabulary\n",
    "- However, suppose we want to treat it as a special token that is encoded as a single token ID, like `\"<|endoftext|>\"`; this notebook explains how to do that"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65f62ab6-df96-4f88-ab9a-37702cd30f5f",
   "metadata": {},
   "source": [
    "&nbsp;\n",
    "## 1. Adding special tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4379fdb-57ba-4a75-9183-0aee0836c391",
   "metadata": {},
   "source": [
    "- Note that we have to add the new tokens as special tokens. The reason is that we don't have the BPE \"merges\" for them, since these merges are learned during tokenizer training; and even if we had them, it would be very challenging to incorporate them without breaking the existing tokenization scheme (see the BPE from scratch notebook [link] to understand the \"merges\")\n",
    "- Suppose we want to add 2 new tokens and assign them the next available token IDs after the existing vocabulary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "265f1bba-c478-497d-b7fc-f4bd191b7d55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define custom tokens and their token IDs\n",
    "custom_tokens = [\"MyNewToken_1\", \"MyNewToken_2\"]\n",
    "custom_token_ids = {\n",
    "    token: base_tokenizer.n_vocab + i for i, token in enumerate(custom_tokens)\n",
    "}"
   ]
  },
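  {
   "cell_type": "markdown",
   "id": "4f2a9c1e-7b3d-4e8a-9f61-2d5c8b0a7e13",
   "metadata": {},
   "source": [
    "- The new IDs start right after the existing vocabulary (`base_tokenizer.n_vocab` is 50,257 for the GPT-2 encoding, so the two new tokens receive the IDs 50257 and 50258); this way, they cannot collide with existing token IDs\n",
    "- If tokens are added repeatedly, it can also be useful to skip tokens that already exist. Below is a minimal sketch that works on plain dictionaries; `assign_new_token_ids` is a hypothetical helper, not a tiktoken function:\n",
    "\n",
    "```python\n",
    "def assign_new_token_ids(existing_vocab, new_tokens):\n",
    "    # existing_vocab: dict mapping tokens to IDs (regular and special tokens)\n",
    "    # Returns IDs only for tokens that are not in the vocabulary yet\n",
    "    next_id = max(existing_vocab.values()) + 1 if existing_vocab else 0\n",
    "    assigned = {}\n",
    "    for token in new_tokens:\n",
    "        if token in existing_vocab or token in assigned:\n",
    "            continue  # skip duplicates instead of assigning a second ID\n",
    "        assigned[token] = next_id\n",
    "        next_id += 1\n",
    "    return assigned\n",
    "\n",
    "\n",
    "# Toy vocabulary that only mimics the last entry of the GPT-2 vocabulary\n",
    "toy_vocab = {\"<|endoftext|>\": 50256}\n",
    "new_ids = assign_new_token_ids(\n",
    "    toy_vocab,\n",
    "    [\"MyNewToken_1\", \"MyNewToken_2\", \"MyNewToken_1\", \"<|endoftext|>\"],\n",
    ")\n",
    "print(new_ids)  # {'MyNewToken_1': 50257, 'MyNewToken_2': 50258}\n",
    "```\n",
    "\n",
    "- The returned dictionary has the same form as `custom_token_ids` above, so it could be merged into the special tokens in the same way"
   ]
  },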
  {
   "cell_type": "markdown",
   "id": "1c6f3d98-1ab6-43cf-9ae2-2bf53860f99e",
   "metadata": {},
   "source": [
    "- Next, we create a custom `Encoding` object that holds our special tokens as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1f519852-59ea-4069-a8c7-0f647bfaea09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a new Encoding object with extended tokens\n",
    "extended_tokenizer = tiktoken.Encoding(\n",
    "    name=\"gpt2_custom\",\n",
    "    pat_str=base_tokenizer._pat_str,\n",
    "    mergeable_ranks=base_tokenizer._mergeable_ranks,\n",
    "    special_tokens={**base_tokenizer._special_tokens, **custom_token_ids},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90af6cfa-e0cc-4c80-89dc-3a824e7bdeb2",
   "metadata": {},
   "source": [
    "- That's it. Now, let's check whether the extended tokenizer can encode a sample text containing the new tokens:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "153e8e1d-c4cb-41ff-9c55-1701e9bcae1c",
   "metadata": {},
   "source": [
    "- The new tokens should now be encoded as the single token IDs `50257` and `50258`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "eccc78a4-1fd4-47ba-a114-83ee0a3aec31",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[36674, 2420, 351, 220, 50257, 290, 220, 50258, 13, 220, 50256]\n"
     ]
    }
   ],
   "source": [
    "special_tokens_set = set(custom_tokens) | {\"<|endoftext|>\"}\n",
    "\n",
    "token_ids = extended_tokenizer.encode(\n",
    "    \"Sample text with MyNewToken_1 and MyNewToken_2. <|endoftext|>\",\n",
    "    allowed_special=special_tokens_set\n",
    ")\n",
    "print(token_ids)"
   ]
  },
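  {
   "cell_type": "markdown",
   "id": "8e3b7d20-1c4f-4a96-b5e2-7f0d9c3a6b41",
   "metadata": {},
   "source": [
    "- To make the role of the special tokens more concrete, the sketch below mimics the first step of encoding in a simplified way: special tokens are matched as literal strings and split out of the text, and only the remaining segments would go through regular BPE. Note that `split_on_special_tokens` is our own hypothetical helper, not part of tiktoken's API, and tiktoken's actual implementation differs in its details:\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "\n",
    "def split_on_special_tokens(text, special_tokens):\n",
    "    # Returns a list of (segment, is_special) pairs\n",
    "    if not special_tokens:\n",
    "        return [(text, False)] if text else []\n",
    "    # Try longer special tokens first so that a special token that is a\n",
    "    # prefix of another one cannot shadow the longer match\n",
    "    ordered = sorted(special_tokens, key=len, reverse=True)\n",
    "    pattern = re.compile(\"|\".join(re.escape(tok) for tok in ordered))\n",
    "\n",
    "    segments, start = [], 0\n",
    "    for match in pattern.finditer(text):\n",
    "        if match.start() > start:\n",
    "            # Ordinary text between special tokens would go through BPE\n",
    "            segments.append((text[start:match.start()], False))\n",
    "        segments.append((match.group(), True))\n",
    "        start = match.end()\n",
    "    if start < len(text):\n",
    "        segments.append((text[start:], False))\n",
    "    return segments\n",
    "\n",
    "\n",
    "segments = split_on_special_tokens(\n",
    "    \"Sample text with MyNewToken_1 and MyNewToken_2. <|endoftext|>\",\n",
    "    {\"MyNewToken_1\", \"MyNewToken_2\", \"<|endoftext|>\"},\n",
    ")\n",
    "for segment, is_special in segments:\n",
    "    print(repr(segment), \"-> special token\" if is_special else \"-> regular BPE\")\n",
    "```\n",
    "\n",
    "- Each regular segment ends with the space right before the next special token; since the special token is cut out first, that space cannot be merged with the following word and is encoded on its own, which is consistent with the `220` tokens in the output above"
   ]
  },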
  {
   "cell_type": "markdown",
   "id": "dc0547c1-bbb5-4915-8cf4-caaebcf922eb",
   "metadata": {},
   "source": [
    "- Again, we can also look at it on a per-token level:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7583eff9-b10d-4e3d-802c-f0464e1ef030",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "36674 -> Sample\n",
      "2420 ->  text\n",
      "351 ->  with\n",
      "220 ->  \n",
      "50257 -> MyNewToken_1\n",
      "290 ->  and\n",
      "220 ->  \n",
      "50258 -> MyNewToken_2\n",
      "13 -> .\n",
      "220 ->  \n",
      "50256 -> <|endoftext|>\n"
     ]
    }
   ],
   "source": [
    "for token_id in token_ids:\n",
    "    print(f\"{token_id} -> {extended_tokenizer.decode([token_id])}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17f0764e-e5a9-4226-a384-18c11bd5fec3",
   "metadata": {},
   "source": [
    "- As we can see above, the extended tokenizer now encodes each of the new tokens as a single token ID\n",
    "- However, to use it with a pretrained LLM, we also have to update the LLM's embedding and output layers, as discussed in the next section"
   ]
  },
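  {
   "cell_type": "markdown",
   "id": "b6d41f8a-3e2c-4c7b-8a19-5e0f2d7c9a64",
   "metadata": {},
   "source": [
    "- The steps from this section can also be wrapped into a small reusable function. The sketch below is a hypothetical helper (`extend_encoding` is our own name). Like the code above, it relies on tiktoken's private `_pat_str`, `_mergeable_ranks`, and `_special_tokens` attributes, which are not part of the public API and may change between tiktoken versions. To keep the example self-contained and runnable offline, it uses a toy byte-level encoding without any merges instead of the GPT-2 encoding:\n",
    "\n",
    "```python\n",
    "import tiktoken\n",
    "\n",
    "\n",
    "def extend_encoding(base, new_tokens, name=None):\n",
    "    # Copy the existing special tokens and append the new ones,\n",
    "    # starting right after the largest existing token ID\n",
    "    special_tokens = dict(base._special_tokens)\n",
    "    next_id = base.n_vocab\n",
    "    for token in new_tokens:\n",
    "        if token not in special_tokens:\n",
    "            special_tokens[token] = next_id\n",
    "            next_id += 1\n",
    "    # Reuse the regex pattern and the BPE merges of the base encoding\n",
    "    return tiktoken.Encoding(\n",
    "        name=name or f\"{base.name}_extended\",\n",
    "        pat_str=base._pat_str,\n",
    "        mergeable_ranks=base._mergeable_ranks,\n",
    "        special_tokens=special_tokens,\n",
    "    )\n",
    "\n",
    "\n",
    "# Toy base encoding: one token per byte (IDs 0-255), no merges,\n",
    "# and a single special token with ID 256\n",
    "toy_base = tiktoken.Encoding(\n",
    "    name=\"toy_bytes\",\n",
    "    pat_str=r\"\\S+|\\s+\",\n",
    "    mergeable_ranks={bytes([i]): i for i in range(256)},\n",
    "    special_tokens={\"<|endoftext|>\": 256},\n",
    ")\n",
    "\n",
    "toy_extended = extend_encoding(toy_base, [\"MyNewToken_1\", \"MyNewToken_2\"])\n",
    "ids = toy_extended.encode(\"hi MyNewToken_1\", allowed_special={\"MyNewToken_1\"})\n",
    "print(ids)\n",
    "print(toy_extended.decode(ids))\n",
    "```\n",
    "\n",
    "- Applied to the GPT-2 encoding, `extend_encoding(base_tokenizer, custom_tokens)` would be expected to assign the same IDs (50257 and 50258) as `extended_tokenizer` above"
   ]
  },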
  {
   "cell_type": "markdown",
   "id": "8ec7f98d-8f09-4386-83f0-9bec68ef7f66",
   "metadata": {},
   "source": [
    "&nbsp;\n",
    "## 2. Updating a pretrained LLM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8a4f68b-04e9-4524-8df4-8718c7b566f2",
   "metadata": {},
   "source": [
    "- In this section, we take a look at how to update an existing pretrained LLM after extending its tokenizer\n",
    "- For this, we use the original pretrained GPT-2 model from the main book"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a9b252e-1d1d-4ddf-b9f3-95bd6ba505a9",
   "metadata": {},
   "source": [
    "&nbsp;\n",
    "### 2.1 Loading a pretrained GPT model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ded29b4e-9b39-4191-b61c-29d6b2360bae",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "checkpoint: 100%|███████████████████████████| 77.0/77.0 [00:00<00:00, 34.4kiB/s]\n",
      "encoder.json: 100%|███████████████████████| 1.04M/1.04M [00:00<00:00, 4.78MiB/s]\n",
      "hparams.json: 100%|█████████████████████████| 90.0/90.0 [00:00<00:00, 24.7kiB/s]\n",
      "model.ckpt.data-00000-of-00001: 100%|███████| 498M/498M [00:33<00:00, 14.7MiB/s]\n",
      "model.ckpt.index: 100%|███████████████████| 5.21k/5.21k [00:00<00:00, 1.05MiB/s]\n",
      "model.ckpt.meta: 100%|██████████████████████| 471k/471k [00:00<00:00, 2.33MiB/s]\n",
      "vocab.bpe: 100%|████████████████████████████| 456k/456k [00:00<00:00, 2.45MiB/s]\n"
     ]
    }
   ],
   "source": [
    "from llms_from_scratch.ch05 import download_and_load_gpt2\n",
    "# For llms_from_scratch installation instructions, see:\n",
    "# https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg\n",
    "\n",
    "settings, params = download_and_load_gpt2(model_size=\"124M\", models_dir=\"gpt2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "93dc0d8e-b549-415b-840e-a00023bddcf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llms_from_scratch.ch04 import GPTModel\n",
    "# For llms_from_scratch installation instructions, see:\n",
    "# https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg\n",
    "\n",
    "GPT_CONFIG_124M = {\n",
    "    \"vocab_size\": 50257,   # Vocabulary size\n",
    "    \"context_length\": 256, # Shortened context length (orig: 1024)\n",
    "    \"emb_dim\": 768,        # Embedding dimension\n",
    "    \"n_heads\": 12,         # Number of attention heads\n",
    "    \"n_layers\": 12,        # Number of layers\n",
    "    \"drop_rate\": 0.1,      # Dropout rate\n",
    "    \"qkv_bias\": False      # Query-key-value bias\n",
    "}\n",
    "\n",
    "# Define model configurations in a dictionary for compactness\n",
    "model_configs = {\n",
    "    \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
    "    \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
    "    \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
    "    \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
    "}\n",
    "\n",
    "# Copy the base configuration and update with specific model settings\n",
    "model_name = \"gpt2-small (124M)\"  # Example model name\n",
    "NEW_CONFIG = GPT_CONFIG_124M.copy()\n",
    "NEW_CONFIG.update(model_configs[model_name])\n",
    "NEW_CONFIG.update({\"context_length\": 1024, \"qkv_bias\": True})\n",
    "\n",
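    "gpt = GPTModel(NEW_CONFIG)\n",
    "\n",
    "# Load the pretrained GPT-2 weights; without this step, the model\n",
    "# would only contain randomly initialized weights\n",
    "from llms_from_scratch.ch05 import load_weights_into_gpt\n",
    "\n",
    "load_weights_into_gpt(gpt, params)\n",
    "gpt.eval();"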
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83f898c0-18f4-49ce-9b1f-3203a277b29e",
   "metadata": {},
   "source": [
    "&nbsp;\n",
    "### 2.2 Using the pretrained GPT model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5a1f5e1-e806-4c60-abaa-42ae8564908c",
   "metadata": {},
   "source": [
    "- Next, consider our sample text below, which we tokenize using the original and the new tokenizer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9a88017d-cc8f-4ba1-bba9-38161a30f673",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sample_text = \"Sample text with MyNewToken_1 and MyNewToken_2. <|endoftext|>\"\n",
    "\n",
    "original_token_ids = base_tokenizer.encode(\n",
    "    sample_text, allowed_special={\"<|endoftext|>\"}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1ee01bc3-ca24-497b-b540-3d13c52c29ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_token_ids = extended_tokenizer.encode(\n",
    "    sample_text, allowed_special=special_tokens_set\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1143106b-68fe-4234-98ad-eaff420a4d08",
   "metadata": {},
   "source": [
    "- Now, let's feed the original token IDs to the GPT model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6b06827f-b411-42cc-b978-5c1d568a3200",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 0.2204,  0.8901,  1.0138,  ...,  0.2585, -0.9192, -0.2298],\n",
      "         [ 0.6745, -0.0726,  0.8218,  ..., -0.1768, -0.4217,  0.0703],\n",
      "         [-0.2009,  0.0814,  0.2417,  ...,  0.3166,  0.3629,  1.3400],\n",
      "         ...,\n",
      "         [ 0.1137, -0.1258,  2.0193,  ..., -0.0314, -0.4288, -0.1487],\n",
      "         [-1.1983, -0.2050, -0.1337,  ..., -0.0849, -0.4863, -0.1076],\n",
      "         [-1.0675, -0.5905,  0.2873,  ..., -0.0979, -0.8713,  0.8415]]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "with torch.no_grad():\n",
    "    out = gpt(torch.tensor([original_token_ids]))\n",
    "\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "082c7a78-35a8-473e-a08d-b099a6348a74",
   "metadata": {},
   "source": [
    "- As we can see above, this works without problems (note that, for simplicity, the code shows the raw output logits without converting them back into text; for more details on that, please check out the `generate` function in Chapter 5 [link], section 5.3.3)"
   ]
  },
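  {
   "cell_type": "markdown",
   "id": "2c9e6a47-8d1b-4f35-a0c8-6b3e1f9d2a75",
   "metadata": {},
   "source": [
    "- The conversion into text mentioned above can be sketched as follows. For greedy decoding, we take the logits at the last position and select the token ID with the largest value (applying softmax first would not change which ID is largest). The snippet uses a small random stand-in tensor with the same layout as `out` (batch size, number of tokens, vocabulary size) instead of the actual model output. It performs only a single decoding step, not the full `generate` function from Chapter 5:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "# Stand-in for the model output: (batch_size, num_tokens, vocab_size)\n",
    "batch_size, num_tokens, vocab_size = 1, 4, 10\n",
    "logits = torch.randn(batch_size, num_tokens, vocab_size)\n",
    "\n",
    "# Only the last position is needed to predict the next token\n",
    "last_logits = logits[:, -1, :]                     # (batch_size, vocab_size)\n",
    "next_token_id = torch.argmax(last_logits, dim=-1)  # (batch_size,)\n",
    "print(next_token_id)\n",
    "\n",
    "# With a tiktoken tokenizer, the ID could then be decoded into text,\n",
    "# e.g., via base_tokenizer.decode(next_token_id.tolist())\n",
    "```\n",
    "\n",
    "- Appending the selected ID to the input and repeating this step yields a simple greedy text generation loop"
   ]
  },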
  {
   "cell_type": "markdown",
   "id": "628265b5-3dde-44e7-bde2-8fc594a2547d",
   "metadata": {},
   "source": [
    "- What happens if we try the same on the token IDs generated by the updated tokenizer now?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9796ad09-787c-4c25-a7f5-6d1dfe048ac3",
   "metadata": {},
   "source": [
    "```python\n",
    "with torch.no_grad():\n",
    "    out = gpt(torch.tensor([new_token_ids]))\n",
    "\n",
    "print(out)\n",
    "\n",
    "...\n",
    "# IndexError: index out of range in self\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77d00244-7e40-4de0-942e-e15cdd8e3b18",
   "metadata": {},
   "source": [
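    "- As we can see, this results in an index error\n",
    "- The reason is that the GPT model has a fixed vocabulary size, which determines the size of its input embedding layer and its output layer; the embedding layer only has rows for the token IDs 0 to 50256, so the new token IDs 50257 and 50258 cannot be looked up:\n",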
    "\n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/extend-tiktoken/gpt-updates.webp\" width=\"400px\">"
   ]
  },
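  {
   "cell_type": "markdown",
   "id": "d7a3e915-4b6c-4e20-9f8d-1a5c7e3b0f86",
   "metadata": {},
   "source": [
    "- The index error can be reproduced with a tiny embedding layer. An embedding layer is a lookup table with one row per token ID, so an `Embedding` with `num_embeddings` rows only accepts IDs from 0 to `num_embeddings - 1`. The sizes below are arbitrary toy values standing in for GPT-2's 50,257 x 768 embedding layer:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "vocab_size, emb_dim = 5, 3\n",
    "embedding = torch.nn.Embedding(vocab_size, emb_dim)\n",
    "\n",
    "# Valid IDs (0 to vocab_size - 1) select rows of the weight matrix\n",
    "valid_ids = torch.tensor([0, 4])\n",
    "print(torch.equal(embedding(valid_ids), embedding.weight[valid_ids]))  # True\n",
    "\n",
    "# An ID equal to vocab_size has no corresponding row (on the CPU,\n",
    "# PyTorch raises an IndexError)\n",
    "try:\n",
    "    embedding(torch.tensor([vocab_size]))\n",
    "    raised = False\n",
    "except IndexError as err:\n",
    "    raised = True\n",
    "    print(\"IndexError:\", err)\n",
    "```"
   ]
  },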
  {
   "cell_type": "markdown",
   "id": "dec38b24-c845-4090-96a4-0d3c4ec241d6",
   "metadata": {},
   "source": [
    "&nbsp;\n",
    "### 2.3 Updating the embedding layer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1328726-8297-4162-878b-a5daff7de742",
   "metadata": {},
   "source": [
    "- Let's start with updating the embedding layer\n",
    "- First, notice that the embedding layer has 50,257 entries, which corresponds to the vocabulary size:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "23ecab6e-1232-47c7-a318-042f90e1dff3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Embedding(50257, 768)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gpt.tok_emb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d760c683-d082-470a-bff8-5a08b30d3b61",
   "metadata": {},
   "source": [
    "- We want to extend this embedding layer by adding 2 more entries\n",
    "- In short, we create a new embedding layer with a bigger size, and then we copy over the old embedding layer values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4ec5c48e-c6fe-4e84-b290-04bd4da9483f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding(50259, 768)\n"
     ]
    }
   ],
   "source": [
    "num_tokens, emb_size = gpt.tok_emb.weight.shape\n",
    "new_num_tokens = num_tokens + 2\n",
    "\n",
    "# Create a new embedding layer\n",
    "new_embedding = torch.nn.Embedding(new_num_tokens, emb_size)\n",
    "\n",
    "# Copy weights from the old embedding layer\n",
    "new_embedding.weight.data[:num_tokens] = gpt.tok_emb.weight.data\n",
    "\n",
    "# Replace the old embedding layer with the new one in the model\n",
    "gpt.tok_emb = new_embedding\n",
    "\n",
    "print(gpt.tok_emb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63954928-31a5-4e7e-9688-2e0c156b7302",
   "metadata": {},
   "source": [
    "- As we can see above, the embedding layer now has 50,259 entries: the 50,257 pretrained embeddings plus 2 new, randomly initialized ones"
   ]
  },
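  {
   "cell_type": "markdown",
   "id": "5b8f2c63-9e0a-4d17-b4c1-8f6a2e9d3c97",
   "metadata": {},
   "source": [
    "- The new rows above use PyTorch's default `Embedding` initialization, a standard normal distribution. A commonly used alternative heuristic, which this notebook does not use, is to initialize the new rows with the mean of the existing embeddings, so that the new tokens start close to an \"average\" token. Below is a minimal sketch, where `extend_embedding` is a hypothetical helper name:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def extend_embedding(old_embedding, num_new_tokens):\n",
    "    num_tokens, emb_dim = old_embedding.weight.shape\n",
    "    new_embedding = torch.nn.Embedding(\n",
    "        num_tokens + num_new_tokens, emb_dim,\n",
    "        device=old_embedding.weight.device,\n",
    "        dtype=old_embedding.weight.dtype,\n",
    "    )\n",
    "    with torch.no_grad():\n",
    "        # Keep the pretrained rows unchanged\n",
    "        new_embedding.weight[:num_tokens] = old_embedding.weight\n",
    "        # Heuristic: start the new rows at the mean of the existing rows\n",
    "        new_embedding.weight[num_tokens:] = old_embedding.weight.mean(dim=0)\n",
    "    return new_embedding\n",
    "\n",
    "\n",
    "torch.manual_seed(123)\n",
    "old = torch.nn.Embedding(10, 4)\n",
    "new = extend_embedding(old, num_new_tokens=2)\n",
    "print(new)  # Embedding(12, 4)\n",
    "```\n",
    "\n",
    "- Unlike the plain `torch.nn.Embedding` call above, this helper also keeps the device and data type of the original layer; other constructor arguments (for example, `padding_idx`) are not carried over in this sketch"
   ]
  },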
  {
   "cell_type": "markdown",
   "id": "6e68bea5-255b-47bb-b352-09ea9539bc25",
   "metadata": {},
   "source": [
    "&nbsp;\n",
    "### 2.4 Updating the output layer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90a4a519-bf0f-4502-912d-ef0ac7a9deab",
   "metadata": {},
   "source": [
    "- Next, we have to extend the output layer, which, similar to the embedding layer, has 50,257 output features corresponding to the vocabulary size (by the way, you may find the bonus material that discusses the similarities between Linear and Embedding layers in PyTorch useful)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6105922f-d889-423e-bbcc-bc49156d78df",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Linear(in_features=768, out_features=50257, bias=False)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gpt.out_head"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29f1ff24-9c00-40f6-a94f-82d03aaf0890",
   "metadata": {},
   "source": [
    "- The procedure for extending the output layer is similar to extending the embedding layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "354589db-b148-4dae-8068-62132e3fb38e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(in_features=768, out_features=50259, bias=False)\n"
     ]
    }
   ],
   "source": [
    "original_out_features, original_in_features = gpt.out_head.weight.shape\n",
    "\n",
    "# Define the new number of output features (e.g., adding 2 new tokens)\n",
    "new_out_features = original_out_features + 2\n",
    "\n",
    "# Create a new linear layer with the extended output size; we match the\n",
    "# original layer's bias setting (GPT-2's output layer has no bias), since a\n",
    "# randomly initialized bias would shift the logits of all existing tokens\n",
    "new_linear = torch.nn.Linear(\n",
    "    original_in_features, new_out_features,\n",
    "    bias=gpt.out_head.bias is not None\n",
    ")\n",
    "\n",
    "# Copy the weights and biases from the original linear layer\n",
    "with torch.no_grad():\n",
    "    new_linear.weight[:original_out_features] = gpt.out_head.weight\n",
    "    if gpt.out_head.bias is not None:\n",
    "        new_linear.bias[:original_out_features] = gpt.out_head.bias\n",
    "\n",
    "# Replace the original linear layer with the new one\n",
    "gpt.out_head = new_linear\n",
    "\n",
    "print(gpt.out_head)"
   ]
  },
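  {
   "cell_type": "markdown",
   "id": "e4c17b39-2f5d-4a8e-8c03-9d7b1a6e4f28",
   "metadata": {},
   "source": [
    "- A quick sanity check for this kind of model surgery: the extended output layer should reproduce the original logits for all existing token IDs, and only the extra columns at the end should be new. The sketch below runs this check on a small stand-alone layer. The sizes are arbitrary toy values, and `extend_linear_out` is a hypothetical helper that follows the same steps as the code above while keeping the original layer's bias setting:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def extend_linear_out(old_linear, num_new_tokens):\n",
    "    out_features, in_features = old_linear.weight.shape\n",
    "    has_bias = old_linear.bias is not None\n",
    "    new_linear = torch.nn.Linear(\n",
    "        in_features, out_features + num_new_tokens, bias=has_bias\n",
    "    )\n",
    "    with torch.no_grad():\n",
    "        new_linear.weight[:out_features] = old_linear.weight\n",
    "        if has_bias:\n",
    "            new_linear.bias[:out_features] = old_linear.bias\n",
    "    return new_linear\n",
    "\n",
    "\n",
    "torch.manual_seed(123)\n",
    "emb_dim, vocab_size = 8, 20\n",
    "old_head = torch.nn.Linear(emb_dim, vocab_size, bias=False)\n",
    "new_head = extend_linear_out(old_head, num_new_tokens=2)\n",
    "\n",
    "hidden = torch.randn(1, 3, emb_dim)  # (batch_size, num_tokens, emb_dim)\n",
    "with torch.no_grad():\n",
    "    old_logits = old_head(hidden)\n",
    "    new_logits = new_head(hidden)\n",
    "\n",
    "print(new_logits.shape)  # torch.Size([1, 3, 22])\n",
    "# The first vocab_size columns match the original logits\n",
    "print(torch.allclose(new_logits[..., :vocab_size], old_logits, atol=1e-6))  # True\n",
    "```\n",
    "\n",
    "- If the new layer were created with a bias while the original layer has none, the randomly initialized bias would shift all logits, and this check would fail"
   ]
  },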
  {
   "cell_type": "markdown",
   "id": "df5d2205-1fae-4a4f-a7bd-fa8fc37eeec2",
   "metadata": {},
   "source": [
    "- Let's try this updated model on the original token IDs first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "df604bbc-6c13-4792-8ba8-ecb692117c25",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 0.2267,  0.9132,  1.0494,  ..., -0.2330, -0.3008, -1.1458],\n",
      "         [ 0.6808, -0.0495,  0.8574,  ...,  0.0671,  0.5572, -0.7873],\n",
      "         [-0.1947,  0.1045,  0.2773,  ...,  1.3368,  0.8479, -0.9660],\n",
      "         ...,\n",
      "         [ 0.1200, -0.1027,  2.0549,  ..., -0.1519, -0.2096,  0.5651],\n",
      "         [-1.1920, -0.1819, -0.0981,  ..., -0.1108,  0.8435, -0.3771],\n",
      "         [-1.0612, -0.5674,  0.3229,  ...,  0.8383, -0.7121, -0.4850]]])\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    output = gpt(torch.tensor([original_token_ids]))\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d80717e-50e6-4927-8129-0aadfa2628f5",
   "metadata": {},
   "source": [
    "- Next, let's try it on the token IDs produced by the extended tokenizer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "75f11ec9-bdd2-440f-b8c8-6646b75891c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 0.2267,  0.9132,  1.0494,  ..., -0.2330, -0.3008, -1.1458],\n",
      "         [ 0.6808, -0.0495,  0.8574,  ...,  0.0671,  0.5572, -0.7873],\n",
      "         [-0.1947,  0.1045,  0.2773,  ...,  1.3368,  0.8479, -0.9660],\n",
      "         ...,\n",
      "         [-0.0656, -1.2451,  0.7957,  ..., -1.2124,  0.1044,  0.5088],\n",
      "         [-1.1561, -0.7380, -0.0645,  ..., -0.4373,  1.1401, -0.3903],\n",
      "         [-0.8961, -0.6437, -0.1667,  ...,  0.5663, -0.5862, -0.4020]]])\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    output = gpt(torch.tensor([new_token_ids]))\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d88a1bba-db01-4090-97e4-25dfc23ed54c",
   "metadata": {},
   "source": [
    "- As we can see, the model now also works on the token IDs from the extended tokenizer\n",
    "- In practice, we would now finetune (or continually pretrain) the model, in particular the new rows of the embedding and output layers, on data that contains the new tokens, since these new rows are still randomly initialized"
   ]
  },
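  {
   "cell_type": "markdown",
   "id": "9a6d3e52-7c1f-4b84-a2e9-3f8c5d0b7a19",
   "metadata": {},
   "source": [
    "- One possible way to focus the finetuning on the new tokens is to freeze the pretrained rows by zeroing their gradients with a tensor hook, so that only the rows of the new token IDs receive updates. This is a sketch under our own assumptions, not a method prescribed in this notebook. The toy example below uses a small embedding layer whose last 2 rows play the role of the new tokens:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(123)\n",
    "num_old_tokens, num_new_tokens, emb_dim = 6, 2, 4\n",
    "embedding = torch.nn.Embedding(num_old_tokens + num_new_tokens, emb_dim)\n",
    "weights_before = embedding.weight.detach().clone()\n",
    "\n",
    "# Mask that is 0 for the pretrained rows and 1 for the new rows\n",
    "grad_mask = torch.zeros_like(embedding.weight)\n",
    "grad_mask[num_old_tokens:] = 1.0\n",
    "embedding.weight.register_hook(lambda grad: grad * grad_mask)\n",
    "\n",
    "# Plain SGD without weight decay, so that zero gradients leave\n",
    "# the pretrained rows exactly unchanged\n",
    "optimizer = torch.optim.SGD(embedding.parameters(), lr=0.1)\n",
    "\n",
    "# Dummy batch that contains both old and new token IDs\n",
    "token_ids = torch.tensor([[0, 3, 6, 7]])\n",
    "loss = embedding(token_ids).pow(2).sum()\n",
    "optimizer.zero_grad()\n",
    "loss.backward()\n",
    "optimizer.step()\n",
    "\n",
    "changed_rows = (embedding.weight.detach() != weights_before).any(dim=1)\n",
    "print(changed_rows)  # only the last 2 rows changed\n",
    "```\n",
    "\n",
    "- The same hook could be registered on the output layer's weight. Optimizers that apply weight decay (such as `torch.optim.AdamW`, whose default `weight_decay` is 0.01) would still shrink the masked rows, so weight decay would need to be disabled for these parameters"
   ]
  },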
  {
   "cell_type": "markdown",
   "id": "6de573ad-0338-40d9-9dad-de60ae349c4f",
   "metadata": {},
   "source": [
    "**A note about weight tying**\n",
    "\n",
    "- If the model uses weight tying, which means that the embedding layer and the output layer share the same weight matrix (as in the Llama 3.2 1B and 3B models [link]), updating the output layer is much simpler\n",
    "- In this case, we can simply reuse the already extended embedding weights for the output layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "4cbc5f51-c7a8-49d0-b87f-d3d87510953b",
   "metadata": {},
   "outputs": [],
   "source": [
    "gpt.out_head.weight = gpt.tok_emb.weight"
   ]
  },
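  {
   "cell_type": "markdown",
   "id": "3f7e8a14-6d2b-4c59-9e3a-0b4d8c7f1e52",
   "metadata": {},
   "source": [
    "- Note that this assignment does not copy any values. Afterwards, both layers refer to the same `Parameter` object, so any update to the embedding weights during training is automatically reflected in the output layer. This requires the output layer's weight to have the same shape as the embedding weight, which holds here because both layers were extended to 50,259 tokens. A small stand-alone check with toy sizes:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "vocab_size, emb_dim = 12, 4\n",
    "tok_emb = torch.nn.Embedding(vocab_size, emb_dim)\n",
    "out_head = torch.nn.Linear(emb_dim, vocab_size, bias=False)\n",
    "\n",
    "# Weight tying: the Linear weight has shape (out_features, in_features)\n",
    "# = (vocab_size, emb_dim), the same as the Embedding weight\n",
    "out_head.weight = tok_emb.weight\n",
    "\n",
    "print(out_head.weight is tok_emb.weight)  # True\n",
    "\n",
    "# An in-place update through one layer is visible through the other\n",
    "with torch.no_grad():\n",
    "    tok_emb.weight[0].fill_(1.0)\n",
    "print(out_head.weight[0])\n",
    "```\n",
    "\n",
    "- Because `Module.parameters()` yields a shared parameter only once, the tied weight is also counted and passed to the optimizer only once"
   ]
  },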
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "d0d553a8-edff-40f0-bdc4-dff900e16caf",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    output = gpt(torch.tensor([new_token_ids]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
